import gym
from gym import spaces
import numpy as np
import math
from gym.utils import seeding
from gym.envs.registration import register

class StochasticContinuousAcrobotEnv(gym.Env):
    """
    A continuous-action version of Acrobot with optional noise.

    The torque is in [-1, +1]. Gaussian noise can be added to:
      - the action (action_noise_scale)
      - the angular accelerations (dynamics_noise_scale)
      - the observed state (obs_noise_scale)

    The internal state is ``[theta1, theta2, thetaDot1, thetaDot2]`` and the
    observation is ``[cos(theta1), sin(theta1), cos(theta2), sin(theta2),
    thetaDot1, thetaDot2]``. An episode terminates once the end-effector
    height exceeds ``goal_height``.
    """

    metadata = {
        "render.modes": ["human", "rgb_array"],
        "video.frames_per_second": 15
    }

    def __init__(
        self,
        action_noise_scale=0.0,
        dynamics_noise_scale=0.0,
        obs_noise_scale=0.0
    ):
        """
        :param action_noise_scale: float, std dev of Gaussian added to torque
        :param dynamics_noise_scale: float, std dev of Gaussian added to accel
        :param obs_noise_scale: float, std dev of Gaussian added to observations
        """
        # Physical constants of the two-link pendulum.
        self.max_speed_1 = 4 * math.pi
        self.max_speed_2 = 9 * math.pi
        self.max_torque = 1.0  # continuous torque range is [-1, +1]
        self.dt = 0.2
        self.link_length_1 = 1.0
        self.link_length_2 = 1.0
        self.link_mass_1 = 1.0
        self.link_mass_2 = 1.0
        self.link_com_pos_1 = 0.5
        self.link_com_pos_2 = 0.5
        self.link_moi = 1.0
        # Gravity used by the dynamics; deliberately larger than 9.8 here.
        self.gravity = 30.0
        self.goal_height = 1.0  # must get the end-effector above 1.0

        # Stochastic parameters
        self.action_noise_scale = action_noise_scale
        self.dynamics_noise_scale = dynamics_noise_scale
        self.obs_noise_scale = obs_noise_scale

        high = np.array([1., 1., 1., 1., self.max_speed_1, self.max_speed_2], dtype=np.float32)
        self.observation_space = spaces.Box(low=-high, high=high, dtype=np.float32)
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)

        self.viewer = None
        self.state = None
        self.seed()
        self.reset()

    def seed(self, seed=None):
        """Create the environment's random generator and return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def _wrap_angle(self, x):
        """Map angle ``x`` (radians) into the interval [-pi, pi)."""
        return ((x + np.pi) % (2 * np.pi)) - np.pi

    def _accelerations(self, theta1, theta2, thetaDot1, thetaDot2, torque):
        """Return ``(accel1, accel2)`` from the Acrobot equations of motion."""
        m1 = self.link_mass_1
        m2 = self.link_mass_2
        l1 = self.link_length_1
        lc1 = self.link_com_pos_1
        lc2 = self.link_com_pos_2
        I1 = self.link_moi
        I2 = self.link_moi
        g = self.gravity

        d1 = (m1 * lc1**2
              + m2 * (l1**2 + lc2**2 + 2 * l1 * lc2 * math.cos(theta2))
              + I1 + I2)
        d2 = m2 * (lc2**2 + l1 * lc2 * math.cos(theta2)) + I2

        phi2 = m2 * lc2 * g * math.cos(theta1 + theta2 - math.pi / 2)
        # The factor 2 belongs to the thetaDot2*thetaDot1 cross term, not to
        # the thetaDot2**2 term (standard Acrobot equations of motion).
        phi1 = (-m2 * l1 * lc2 * thetaDot2**2 * math.sin(theta2)
                - 2 * m2 * l1 * lc2 * thetaDot2 * thetaDot1 * math.sin(theta2)
                + (m1 * lc1 + m2 * l1) * g * math.cos(theta1 - math.pi / 2)
                + phi2)

        accel2 = (
            torque
            + d2 / d1 * phi1
            - m2 * l1 * lc2 * thetaDot1**2 * math.sin(theta2)
            - phi2
        ) / (m2 * lc2**2 + I2 - d2**2 / d1)
        accel1 = -(d2 * accel2 + phi1) / d1
        return accel1, accel2

    def _noisy_obs(self):
        """Return the current observation, with Gaussian noise if enabled."""
        obs = self._get_obs()
        if self.obs_noise_scale > 0.0:
            noise = self.np_random.normal(0.0, self.obs_noise_scale, size=obs.shape).astype(np.float32)
            obs += noise
        return obs

    def step(self, action):
        """
        Advance the simulation by one step of length ``self.dt``.

        :param action: torque command; either a scalar or an indexable whose
            first element is the torque. After action noise is added it is
            clipped to [-max_torque, +max_torque].
        :return: ``(obs, reward, done, info)`` where ``obs`` is the (possibly
            noisy) observation, ``reward`` is a float, ``done`` tells whether
            the goal height was exceeded and ``info`` is an empty dict.
        """
        # 1) Add action noise, then clip to the torque range.
        # Scalars and 0-d arrays are accepted in addition to indexables.
        raw_action = float(action) if np.ndim(action) == 0 else float(action[0])
        noisy_action = raw_action + self.np_random.normal(0.0, self.action_noise_scale)
        act = np.clip(noisy_action, -self.max_torque, self.max_torque)

        # Current state
        theta1, theta2, thetaDot1, thetaDot2 = self.state
        accel1, accel2 = self._accelerations(theta1, theta2, thetaDot1, thetaDot2, act)

        # 2) Add some noise to the accelerations if desired
        if self.dynamics_noise_scale > 0.0:
            accel1 += self.np_random.normal(0.0, self.dynamics_noise_scale)
            accel2 += self.np_random.normal(0.0, self.dynamics_noise_scale)

        # Integrate: velocities first, then angles with the updated velocities
        thetaDot1 += accel1 * self.dt
        thetaDot2 += accel2 * self.dt
        theta1 += thetaDot1 * self.dt
        theta2 += thetaDot2 * self.dt

        # Clip speeds to their limits and wrap angles
        thetaDot1 = np.clip(thetaDot1, -self.max_speed_1, self.max_speed_1)
        thetaDot2 = np.clip(thetaDot2, -self.max_speed_2, self.max_speed_2)
        theta1 = self._wrap_angle(theta1)
        theta2 = self._wrap_angle(theta2)

        self.state = np.array([theta1, theta2, thetaDot1, thetaDot2], dtype=np.float32)

        done = self._terminal()

        # 3) Shift the per-step reward of -1 by +2 and clamp it to at least
        # 0.01 so that the returned reward is strictly positive.
        # NOTE(review): original_reward does not depend on `done`, so the
        # returned reward is always 1.0 -- confirm this is intended.
        original_reward = -1.0
        shifted_reward = original_reward + 2.0
        final_reward = max(0.01, shifted_reward)

        # 4) Potentially noisy observation
        obs = self._noisy_obs()

        return obs, final_reward, done, {}

    def _terminal(self):
        """Return True when the end-effector is above ``goal_height``."""
        theta1, theta2, _, _ = self.state
        end_effector_y = (
            - self.link_length_1 * math.cos(theta1)
            - self.link_length_2 * math.cos(theta1 + theta2)
        )
        return bool(end_effector_y > self.goal_height)

    def _get_obs(self):
        """Build the noise-free float32 observation from the current state."""
        theta1, theta2, thetaDot1, thetaDot2 = self.state
        return np.array([
            np.cos(theta1),
            np.sin(theta1),
            np.cos(theta2),
            np.sin(theta2),
            thetaDot1,
            thetaDot2
        ], dtype=np.float32)

    def reset(self):
        """
        Sample a new random state and return its (possibly noisy) observation.

        Angles are drawn uniformly from [-pi, pi] and angular velocities
        uniformly from [-1, 1].
        """
        high = np.array([np.pi, np.pi, 1.0, 1.0], dtype=np.float32)
        low = np.array([-np.pi, -np.pi, -1.0, -1.0], dtype=np.float32)
        rnd = self.np_random.uniform(low, high)
        self.state = np.array(rnd, dtype=np.float32)
        return self._noisy_obs()

    def render(self, mode='human'):
        """Rendering is not implemented; this is a no-op."""
        pass

    def close(self):
        """Close the viewer if one was created."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None

# Register the environment with gym under a fixed id. This runs as a side
# effect of importing this module.
# NOTE(review): the entry_point string names the module "continuous_acrobot";
# presumably that is this file's import name -- confirm it matches.
register(
    id="StochasticContinuousAcrobot-v0",
    entry_point="continuous_acrobot:StochasticContinuousAcrobotEnv",
    max_episode_steps=500,  # step limit handed to gym's registry for this id
)
